#!/usr/bin/env python3
"""
Violence Perception Index (VPI) - Inference Pipeline (SQLite Version)
======================================================================

This script applies fine-tuned Spanish transformer models to score
YouTube comments for violence-related discourse, reading directly from
SQLite databases.

Supports:
- Binary classification (violence / no violence)
- Regression (violence intensity 0-10)
- Batch processing for millions of comments
- Memory-efficient chunked processing from SQLite
- CSV and SQLite output options

Usage:
    # Score all comments from SQLite databases
    python inference_violence_scorer_sqlite.py \
        --db_dir ../data \
        --output_file ./scored_comments.csv \
        --model_dir ./trained_models
    
    # Score with specific databases
    python inference_violence_scorer_sqlite.py \
        --databases ../data/202001_202212_02_mx_youtube_data.db \
                    ../data/202301_202312_02_mx_youtube_data.db \
        --output_file ./scored_comments.csv \
        --model_dir ./trained_models

    # Write results back to SQLite
    python inference_violence_scorer_sqlite.py \
        --db_dir ../data \
        --output_db ./vpi_results.db \
        --model_dir ./trained_models

Requirements:
    pip install transformers torch pandas tqdm

Author: Francesco Bailo
Date: January 2026
"""

import os
import glob
import sqlite3
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import List, Optional, Union, Generator
from tqdm import tqdm

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline
)


# =============================================================================
# CONFIGURATION
# =============================================================================

# Column names assumed when the caller does not override them
DEFAULT_TEXT_COL = 'text_original'
DEFAULT_ID_COL = 'comment_id'

# Glob used to discover the Mexico YouTube databases inside a directory
DEFAULT_DB_PATTERN = "*_mx_youtube_data.db"

# Number of rows pulled from SQLite per read
DB_CHUNK_SIZE = 10_000

# Inference batch-size presets, keyed by memory tier
BATCH_SIZES = dict(
    small=32,       # ~4GB VRAM or CPU
    medium=64,      # ~8GB VRAM
    large=128,      # ~12GB VRAM
    xlarge=256,     # ~24GB+ VRAM
)


# =============================================================================
# MODEL LOADING
# =============================================================================

class ViolenceScorer:
    """
    Wrapper class for violence scoring using fine-tuned models.

    Handles both binary classification and regression scoring. Models are
    looked up in two sub-directories of ``model_dir``:
    ``binary_classifier_final`` and ``regression_model_final``. Either one may
    be missing; the matching attribute then stays ``None`` and
    ``score_dataframe`` skips that scoring step.
    """

    def __init__(
        self,
        model_dir: str,
        device: Optional[str] = None,
        batch_size: int = 64
    ):
        """
        Initialize the scorer with fine-tuned models.

        Args:
            model_dir: Directory containing trained models
            device: Device to use ('cuda', 'mps', 'cpu', or None for auto-detect).
                Any other value falls back to CPU.
            batch_size: Batch size for inference
        """
        self.model_dir = model_dir
        self.batch_size = batch_size

        # Auto-detect device
        if device is None:
            # torch.backends.mps may be missing on older torch builds, so
            # probe for it instead of assuming the attribute exists.
            mps_backend = getattr(torch.backends, 'mps', None)
            if torch.cuda.is_available():
                device = 'cuda'
            elif mps_backend is not None and mps_backend.is_available():
                device = 'mps'
            else:
                device = 'cpu'

        # self.device uses the transformers pipeline convention:
        # 0 for the first CUDA device, 'mps' for Apple silicon, -1 for CPU.
        if device == 'cuda':
            self.device = 0
            self.device_type = 'cuda'
        elif device == 'mps':
            self.device = 'mps'
            self.device_type = 'mps'
        else:
            self.device = -1
            self.device_type = 'cpu'

        print(f"Using device: {self.device_type.upper()}")

        # Every model attribute is created up front so that "model not found"
        # is represented by None rather than by a missing attribute (the
        # scoring methods test these attributes against None).
        self.binary_model = None
        self.regression_model = None
        self.regression_model_raw = None
        self.regression_tokenizer = None
        self.tokenizer = None

        self._load_models()

    def _load_models(self) -> None:
        """Load whichever fine-tuned models exist under ``self.model_dir``."""

        # Check for binary classifier
        binary_path = os.path.join(self.model_dir, "binary_classifier_final")
        if os.path.exists(binary_path):
            print(f"Loading binary classifier from: {binary_path}")

            self.binary_model = pipeline(
                "text-classification",
                model=binary_path,
                tokenizer=binary_path,
                device=self.device,
                batch_size=self.batch_size,
                truncation=True,
                max_length=128
            )
            self.tokenizer = AutoTokenizer.from_pretrained(binary_path)
            print("  ✓ Binary classifier loaded")
        else:
            print(f"  ✗ Binary classifier not found at {binary_path}")

        # Check for regression model
        regression_path = os.path.join(self.model_dir, "regression_model_final")
        if os.path.exists(regression_path):
            print(f"Loading regression model from: {regression_path}")

            # Load model and tokenizer manually for regression: the raw
            # logits are needed, not pipeline-style label/score dicts.
            self.regression_tokenizer = AutoTokenizer.from_pretrained(regression_path)
            self.regression_model_raw = AutoModelForSequenceClassification.from_pretrained(
                regression_path
            )

            # Move to appropriate device
            if self.device_type == 'cuda':
                self.regression_model_raw = self.regression_model_raw.cuda()
            elif self.device_type == 'mps':
                self.regression_model_raw = self.regression_model_raw.to('mps')

            self.regression_model_raw.eval()
            print("  ✓ Regression model loaded")
        else:
            print(f"  ✗ Regression model not found at {regression_path}")

    def score_binary(self, texts: List[str]) -> List[dict]:
        """
        Score texts for binary violence classification.

        Args:
            texts: List of comment texts

        Returns:
            List of dicts with 'label' and 'score' keys, one per input text

        Raises:
            ValueError: If the binary classifier was not loaded
        """
        if self.binary_model is None:
            raise ValueError("Binary classifier not loaded")

        if len(texts) == 0:
            return []

        # Null texts are scored as the empty string so that the output stays
        # aligned one-to-one with the input.
        valid_texts = [str(t) if pd.notna(t) else "" for t in texts]

        return self.binary_model(valid_texts)

    def score_regression(self, texts: List[str]) -> List[float]:
        """
        Score texts for violence intensity (0-10).

        Args:
            texts: List of comment texts

        Returns:
            List of intensity scores, clipped to the range [0, 10]

        Raises:
            ValueError: If the regression model was not loaded
        """
        if self.regression_model_raw is None:
            raise ValueError("Regression model not loaded")

        scores = []

        # Process in batches
        for i in range(0, len(texts), self.batch_size):
            batch_texts = texts[i:i+self.batch_size]

            # Handle empty/null texts
            batch_texts = [str(t) if pd.notna(t) else "" for t in batch_texts]

            # Tokenize
            inputs = self.regression_tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors="pt"
            )

            # Move to appropriate device
            if self.device_type == 'cuda':
                inputs = {k: v.cuda() for k, v in inputs.items()}
            elif self.device_type == 'mps':
                inputs = {k: v.to('mps') for k, v in inputs.items()}

            # Inference
            with torch.no_grad():
                outputs = self.regression_model_raw(**inputs)
                # NOTE(review): assumes the model has a single output unit so
                # that squeezing the last axis leaves one value per text.
                batch_scores = outputs.logits.squeeze(-1).cpu().numpy()

            # Handle single sample case
            if batch_scores.ndim == 0:
                batch_scores = [batch_scores.item()]
            else:
                batch_scores = batch_scores.tolist()

            scores.extend(batch_scores)

        # Clip to valid range; float bounds keep every element a float
        return [max(0.0, min(10.0, s)) for s in scores]

    def score_dataframe(
        self,
        df: pd.DataFrame,
        text_col: str = DEFAULT_TEXT_COL,
        include_binary: bool = True,
        include_regression: bool = True,
        show_progress: bool = True
    ) -> pd.DataFrame:
        """
        Score all comments in a DataFrame.

        A scoring step is skipped silently when its model was not loaded.

        Args:
            df: DataFrame with comments
            text_col: Column name containing text
            include_binary: Whether to include binary predictions
            include_regression: Whether to include regression scores
            show_progress: Whether to show progress bar

        Returns:
            Copy of ``df`` with added score columns: ``vpi_binary_label`` and
            ``vpi_binary_confidence`` (binary), ``vpi_intensity_score``
            (regression)
        """
        df = df.copy()
        texts = df[text_col].tolist()
        n_texts = len(texts)

        # Binary classification
        if include_binary and self.binary_model is not None:
            binary_results = []

            # The context manager closes the bar even if scoring raises;
            # disable=True turns it into a no-op when progress is not wanted.
            with tqdm(
                total=n_texts,
                desc="Binary classification",
                disable=not show_progress
            ) as pbar:
                for i in range(0, n_texts, self.batch_size):
                    batch = texts[i:i+self.batch_size]
                    binary_results.extend(self.score_binary(batch))
                    pbar.update(len(batch))

            # Extract labels and scores.
            # NOTE(review): relies on the default 'LABEL_1' name for the
            # positive class; a model saved with custom id2label names would
            # have every row mapped to 0 here - confirm against the model config.
            df['vpi_binary_label'] = [
                1 if r['label'] == 'LABEL_1' else 0
                for r in binary_results
            ]
            df['vpi_binary_confidence'] = [r['score'] for r in binary_results]

        # Regression scoring
        if include_regression and self.regression_model_raw is not None:
            if show_progress:
                print("  Running regression scoring...")

            df['vpi_intensity_score'] = self.score_regression(texts)

        return df


# =============================================================================
# DATABASE OPERATIONS
# =============================================================================

def get_db_comment_count(db_path: str) -> int:
    """
    Get total number of comments in a database.

    Args:
        db_path: Path to SQLite database containing a ``comments`` table

    Returns:
        Number of rows in the ``comments`` table

    Raises:
        sqlite3.Error: If the database cannot be queried (for example when
            the ``comments`` table does not exist)
    """
    conn = sqlite3.connect(db_path)
    try:
        row = conn.execute("SELECT COUNT(*) FROM comments").fetchone()
        return row[0]
    finally:
        # Release the handle even when the query fails
        conn.close()


def read_comments_from_db(
    db_path: str,
    chunk_size: int = DB_CHUNK_SIZE,
    id_col: str = DEFAULT_ID_COL,
    text_col: str = DEFAULT_TEXT_COL
) -> Generator[pd.DataFrame, None, None]:
    """
    Read comments from SQLite database in chunks.

    The connection is closed when the generator is exhausted, when the
    consumer stops iterating early, and when an error is raised.

    Args:
        db_path: Path to SQLite database
        chunk_size: Number of rows per chunk
        id_col: Column name for comment ID
        text_col: Column name for comment text

    Yields:
        DataFrames holding the ``id_col`` and ``text_col`` columns of the
        ``comments`` table, at most ``chunk_size`` rows each
    """
    conn = sqlite3.connect(db_path)
    try:
        # Column names are interpolated because SQL placeholders cannot stand
        # in for identifiers; they come from the operator's command line.
        query = f"SELECT {id_col}, {text_col} FROM comments"
        yield from pd.read_sql_query(query, conn, chunksize=chunk_size)
    finally:
        conn.close()


def find_databases(db_dir: str, pattern: str = DEFAULT_DB_PATTERN) -> List[str]:
    """
    Locate the SQLite databases in a directory whose names match a glob.

    Args:
        db_dir: Directory to search
        pattern: Glob pattern for database files

    Returns:
        Matching paths in sorted order
    """
    matches = glob.glob(os.path.join(db_dir, pattern))
    matches.sort()
    return matches


# =============================================================================
# PROCESSING FUNCTIONS
# =============================================================================

def process_databases(
    scorer: ViolenceScorer,
    databases: List[str],
    output_file: str,
    chunk_size: int = DB_CHUNK_SIZE,
    text_col: str = DEFAULT_TEXT_COL,
    id_col: str = DEFAULT_ID_COL
) -> dict:
    """
    Score every comment in the given SQLite databases and stream the
    results into a single CSV file.

    Args:
        scorer: ViolenceScorer instance
        databases: List of database file paths
        output_file: Path to output CSV file
        chunk_size: Rows to process at a time
        text_col: Column name for comment text
        id_col: Column name for comment ID

    Returns:
        Dictionary with processing summary (``databases_processed``,
        ``total_comments``, ``comments_per_db``, ``errors``)
    """
    summary = {
        'databases_processed': 0,
        'total_comments': 0,
        'comments_per_db': {},
        'errors': []
    }

    # Pass 1: row counts, used only to size the progress bar
    print("\n--- Counting comments ---")
    expected_total = 0
    for path in databases:
        name = os.path.basename(path)
        n_rows = get_db_comment_count(path)
        summary['comments_per_db'][name] = n_rows
        expected_total += n_rows
        print(f"  {name}: {n_rows:,} comments")

    print(f"\nTotal comments to process: {expected_total:,}")

    # Pass 2: score chunk by chunk and append to the CSV
    print("\n--- Processing databases ---")

    # True until the first chunk has been written; that write creates the
    # file and emits the header, every later write appends without one.
    header_pending = True

    with tqdm(total=expected_total, desc="Scoring comments") as progress:
        for path in databases:
            name = os.path.basename(path)
            print(f"\n  Processing: {name}")

            try:
                chunks = read_comments_from_db(
                    path,
                    chunk_size=chunk_size,
                    id_col=id_col,
                    text_col=text_col
                )
                for raw_chunk in chunks:
                    scored = scorer.score_dataframe(
                        raw_chunk,
                        text_col=text_col,
                        show_progress=False
                    )

                    # Record which database each row came from
                    scored['source_db'] = name

                    scored.to_csv(
                        output_file,
                        index=False,
                        mode='w' if header_pending else 'a',
                        header=header_pending
                    )
                    header_pending = False

                    n_scored = len(raw_chunk)
                    progress.update(n_scored)
                    summary['total_comments'] += n_scored

                summary['databases_processed'] += 1

            except Exception as exc:
                # Best effort: note the failure and move on to the next database
                print(f"  ERROR: {exc}")
                summary['errors'].append({'database': path, 'error': str(exc)})

    return summary


def process_to_sqlite(
    scorer: ViolenceScorer,
    databases: List[str],
    output_db: str,
    chunk_size: int = DB_CHUNK_SIZE,
    text_col: str = DEFAULT_TEXT_COL,
    id_col: str = DEFAULT_ID_COL
) -> dict:
    """
    Process databases and write scored results to a new SQLite database.

    Results go to a ``vpi_scores`` table, which is replaced by the first
    chunk written and appended to afterwards. Once all databases have been
    processed an index is created on the ID column, provided at least one
    chunk was written. The output connection is always closed, including
    when an error propagates.

    Args:
        scorer: ViolenceScorer instance
        databases: List of database file paths
        output_db: Path to output SQLite database
        chunk_size: Rows to process at a time
        text_col: Column name for comment text
        id_col: Column name for comment ID

    Returns:
        Dictionary with processing summary (``databases_processed``,
        ``total_comments``, ``comments_per_db``, ``errors``)
    """
    summary = {
        'databases_processed': 0,
        'total_comments': 0,
        'comments_per_db': {},
        'errors': []
    }

    out_conn = sqlite3.connect(output_db)

    # Name of the ID column as actually stored in vpi_scores. It stays None
    # until the first chunk has been written, so it doubles as the
    # "table has been created" flag.
    index_column = None

    try:
        # Count total comments
        print("\n--- Counting comments ---")
        total_comments = 0
        for db_path in databases:
            db_name = os.path.basename(db_path)
            count = get_db_comment_count(db_path)
            summary['comments_per_db'][db_name] = count
            total_comments += count
            print(f"  {db_name}: {count:,} comments")

        print(f"\nTotal comments to process: {total_comments:,}")

        # Process each database
        print("\n--- Processing databases ---")

        with tqdm(total=total_comments, desc="Scoring comments") as pbar:
            for db_path in databases:
                db_name = os.path.basename(db_path)
                print(f"\n  Processing: {db_name}")

                try:
                    for chunk_df in read_comments_from_db(
                        db_path,
                        chunk_size=chunk_size,
                        id_col=id_col,
                        text_col=text_col
                    ):
                        # Score the chunk
                        scored_chunk = scorer.score_dataframe(
                            chunk_df,
                            text_col=text_col,
                            show_progress=False
                        )

                        # Add source database column
                        scored_chunk['source_db'] = db_name

                        # First successful write replaces any existing
                        # table; later writes append to it.
                        scored_chunk.to_sql(
                            'vpi_scores',
                            out_conn,
                            index=False,
                            if_exists='replace' if index_column is None else 'append'
                        )
                        if index_column is None:
                            # The source query selects the ID column first
                            index_column = str(chunk_df.columns[0])

                        pbar.update(len(chunk_df))
                        summary['total_comments'] += len(chunk_df)

                    summary['databases_processed'] += 1

                except Exception as e:
                    # Best effort: record the failure and continue
                    print(f"  ERROR: {e}")
                    summary['errors'].append({'database': db_path, 'error': str(e)})

        if index_column is not None:
            # Create index on the ID column for faster lookups. The column is
            # the one that was really written, so a non-default id_col works.
            print("\n--- Creating database index ---")
            quoted_column = '"' + index_column.replace('"', '""') + '"'
            out_conn.execute(
                f"CREATE INDEX IF NOT EXISTS idx_comment_id ON vpi_scores({quoted_column})"
            )
            out_conn.commit()
        else:
            # No table was created in this run, so there is nothing to index
            print("\n--- No rows written; skipping index creation ---")
    finally:
        out_conn.close()

    return summary


# =============================================================================
# SUMMARY STATISTICS
# =============================================================================

def print_scoring_summary(output_file: str, sample_size: int = 100000):
    """
    Print summary statistics from output file.

    The file is read as SQLite when its name ends in ``.db`` or when it
    starts with the SQLite file header; otherwise it is read as CSV.
    Statistics are computed on the first ``sample_size`` rows (not a random
    sample), while the total is counted over the whole output. A missing or
    empty output is reported instead of raising.

    Args:
        output_file: Path to the CSV file or SQLite database of scored comments
        sample_size: Maximum number of rows used for the statistics
    """

    print("\n" + "="*60)
    print("SCORING SUMMARY")
    print("="*60)

    if not os.path.isfile(output_file):
        print(f"\nNo output found at {output_file}; nothing to summarise.")
        return

    # Recognise SQLite output by content as well as by extension, so that
    # names such as results.sqlite are not parsed as CSV.
    with open(output_file, 'rb') as fh:
        has_sqlite_header = fh.read(16) == b"SQLite format 3\x00"

    # Read sample for statistics
    if output_file.endswith('.db') or has_sqlite_header:
        conn = sqlite3.connect(output_file)
        try:
            table_row = conn.execute(
                "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'vpi_scores'"
            ).fetchone()
            if table_row is None:
                print(f"\nNo vpi_scores table in {output_file}; nothing to summarise.")
                return
            df = pd.read_sql_query(
                "SELECT * FROM vpi_scores LIMIT ?",
                conn,
                params=(int(sample_size),)
            )
            total_count = int(
                conn.execute("SELECT COUNT(*) FROM vpi_scores").fetchone()[0]
            )
        finally:
            conn.close()
    else:
        # CSV - read first N rows
        try:
            df = pd.read_csv(output_file, nrows=sample_size)
        except pd.errors.EmptyDataError:
            print(f"\n{output_file} is empty; nothing to summarise.")
            return
        # Count parsed records rather than physical lines: comment text can
        # contain newlines inside quoted fields, which would inflate a plain
        # line count. Only the first column is loaded to keep memory low.
        total_count = sum(
            len(part)
            for part in pd.read_csv(output_file, usecols=[0], chunksize=500_000)
        )

    print(f"\nTotal comments scored: {total_count:,}")
    print(f"(Statistics based on sample of {len(df):,} comments)")

    n_total = len(df)
    if n_total == 0:
        # Every percentage below divides by the sample size
        print("\nNo rows available for statistics.")
        return

    if 'vpi_binary_label' in df.columns:
        n_violence = int(df['vpi_binary_label'].sum())
        pct_violence = (n_violence / n_total) * 100

        print("\nBinary Classification:")
        print(f"  Violence detected: {n_violence:,} ({pct_violence:.1f}%)")
        print(f"  No violence: {n_total - n_violence:,} ({100-pct_violence:.1f}%)")

        if 'vpi_binary_confidence' in df.columns:
            print(f"  Mean confidence: {df['vpi_binary_confidence'].mean():.3f}")

    if 'vpi_intensity_score' in df.columns:
        intensity = df['vpi_intensity_score']
        print("\nRegression Scores (0-10):")
        print(f"  Mean: {intensity.mean():.3f}")
        print(f"  Std:  {intensity.std():.3f}")
        print(f"  Min:  {intensity.min():.3f}")
        print(f"  Max:  {intensity.max():.3f}")

        # Distribution by bins; include_lowest puts a score of exactly 0
        # into the first bin.
        bins = [0, 1, 2, 3, 5, 7, 10]
        labels = ['0-1', '1-2', '2-3', '3-5', '5-7', '7-10']
        score_bin = pd.cut(
            intensity,
            bins=bins,
            labels=labels,
            include_lowest=True
        )

        print("\n  Score Distribution:")
        for label in labels:
            count = int((score_bin == label).sum())
            pct = (count / n_total) * 100
            print(f"    {label}: {count:,} ({pct:.1f}%)")

    # Per-database breakdown if available
    if 'source_db' in df.columns:
        print("\n  Comments per source database (in sample):")
        for db, count in df['source_db'].value_counts().items():
            print(f"    {db}: {count:,}")


# =============================================================================
# MAIN
# =============================================================================

def main(args):
    """
    Main inference pipeline.

    Resolves the input databases, loads the models, scores every comment
    into either a SQLite database (``args.output_db``) or a CSV file
    (``args.output_file``), then prints a processing report and scoring
    statistics.

    Args:
        args: Parsed command-line arguments (see the argparse setup at the
            bottom of this file)

    Raises:
        ValueError: If neither ``args.databases`` nor ``args.db_dir`` is set
        FileNotFoundError: If no database matches the pattern, if a listed
            database path is not an existing file, or if no trained model
            could be loaded from ``args.model_dir``
    """

    print("\n" + "="*60)
    print("VPI INFERENCE PIPELINE (SQLite Mode)")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("="*60)

    # Find databases
    if args.databases:
        databases = args.databases
    elif args.db_dir:
        databases = find_databases(args.db_dir, args.db_pattern)
    else:
        raise ValueError("Must specify either --databases or --db_dir")

    if not databases:
        raise FileNotFoundError(
            f"No databases found matching pattern '{args.db_pattern}' in {args.db_dir}"
        )

    # sqlite3.connect() silently creates an empty file for a path that does
    # not exist, so reject bad paths before anything tries to open them.
    missing = [db for db in databases if not os.path.isfile(db)]
    if missing:
        raise FileNotFoundError(
            "Database file(s) not found: " + ", ".join(str(db) for db in missing)
        )

    print(f"\nFound {len(databases)} database(s):")
    for db in databases:
        print(f"  - {db}")

    # Initialize scorer
    print("\n--- Loading Models ---")
    scorer = ViolenceScorer(
        model_dir=args.model_dir,
        device=args.device,
        batch_size=args.batch_size
    )

    # Without any model the run would read every comment and add no scores,
    # so stop before the expensive part.
    if (
        scorer.binary_model is None
        and getattr(scorer, 'regression_model_raw', None) is None
    ):
        raise FileNotFoundError(
            f"No trained models could be loaded from {args.model_dir}"
        )

    # Process databases
    if args.output_db:
        # Output to SQLite
        summary = process_to_sqlite(
            scorer,
            databases,
            args.output_db,
            chunk_size=args.chunk_size,
            text_col=args.text_col,
            id_col=args.id_col
        )
        output_path = args.output_db
    else:
        # Output to CSV
        summary = process_databases(
            scorer,
            databases,
            args.output_file,
            chunk_size=args.chunk_size,
            text_col=args.text_col,
            id_col=args.id_col
        )
        output_path = args.output_file

    # Print summary
    print("\n" + "="*60)
    print("PROCESSING COMPLETE")
    print("="*60)
    print(f"  Databases processed: {summary['databases_processed']}")
    print(f"  Total comments scored: {summary['total_comments']:,}")
    print(f"  Output saved to: {output_path}")

    if summary['errors']:
        print(f"\n  Errors encountered: {len(summary['errors'])}")
        for err in summary['errors']:
            print(f"    - {err['database']}: {err['error']}")

    # Print scoring statistics; skipped when nothing was written because the
    # output may then not exist at all.
    if summary['total_comments'] > 0:
        print_scoring_summary(output_path)
    else:
        print("\nNo comments were scored; skipping scoring summary.")

    print(f"\nFinished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


def positive_int(value: str) -> int:
    """
    Parse a command-line value as an integer greater than zero.

    Intended as an argparse ``type=`` callable for the batch and chunk sizes:
    a batch size of zero would be used as a ``range`` step in the scoring
    loops, which raises ``ValueError``.

    Args:
        value: Raw command-line string

    Returns:
        The parsed integer

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not an integer or is
            less than 1
    """
    try:
        number = int(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        )
    return number


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Score YouTube comments from SQLite databases using fine-tuned violence classifier"
    )

    # Database input options: exactly one of --databases / --db_dir
    db_group = parser.add_mutually_exclusive_group(required=True)

    db_group.add_argument(
        "--databases",
        type=str,
        nargs='+',
        help="Paths to SQLite database files"
    )

    db_group.add_argument(
        "--db_dir",
        type=str,
        help="Directory containing SQLite database files"
    )

    parser.add_argument(
        "--db_pattern",
        type=str,
        default=DEFAULT_DB_PATTERN,
        help=f"Glob pattern for finding databases (default: {DEFAULT_DB_PATTERN})"
    )

    # Output options: exactly one of --output_file / --output_db
    output_group = parser.add_mutually_exclusive_group(required=True)

    output_group.add_argument(
        "--output_file",
        type=str,
        help="Output CSV file path"
    )

    output_group.add_argument(
        "--output_db",
        type=str,
        help="Output SQLite database path"
    )

    # Model options
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="Directory containing trained models"
    )

    parser.add_argument(
        "--device",
        type=str,
        choices=['cuda', 'mps', 'cpu'],
        default=None,
        help="Device to use (default: auto-detect)"
    )

    # Sizes are validated at parse time so a bad value fails with a usage
    # error instead of part-way through a run.
    parser.add_argument(
        "--batch_size",
        type=positive_int,
        default=64,
        help="Batch size for inference (default: 64)"
    )

    parser.add_argument(
        "--chunk_size",
        type=positive_int,
        default=DB_CHUNK_SIZE,
        help=f"Database chunk size for processing (default: {DB_CHUNK_SIZE})"
    )

    # Data column options
    parser.add_argument(
        "--text_col",
        type=str,
        default=DEFAULT_TEXT_COL,
        help=f"Column name for comment text (default: {DEFAULT_TEXT_COL})"
    )

    parser.add_argument(
        "--id_col",
        type=str,
        default=DEFAULT_ID_COL,
        help=f"Column name for comment ID (default: {DEFAULT_ID_COL})"
    )

    args = parser.parse_args()

    main(args)
